import os
import json
import logging
import torch
import torch.nn as nn
import deepspeed
from tensorboardX import SummaryWriter

from data.data import get_data
from model.model import get_model

def load_config(config_path):
    """Load a JSON configuration file and return the parsed object.

    Args:
        config_path: Path to the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file (a dict for the
        configs used in this script).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

def setup_logger(log_dir, stage):
    """Return an INFO-level logger that writes to a file and to the console.

    The logger is named ``stage_<stage>`` and its file is
    ``<log_dir>/stage_<stage>.log``. Calling this function again with the
    same arguments returns the same logger without attaching duplicate
    handlers.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        stage: Identifier used in both the logger name and the file name.

    Returns:
        The configured ``logging.Logger``.
    """
    # exist_ok avoids the check-then-create race when several processes
    # reach this point at the same time.
    os.makedirs(log_dir, exist_ok=True)

    log_file = os.path.join(log_dir, f'stage_{stage}.log')
    logger = logging.getLogger(f'stage_{stage}')
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # logging.getLogger returns the same object for the same name, so only
    # attach the handlers that are still missing; otherwise every repeated
    # call would make each message appear one more time.
    target_file = os.path.abspath(log_file)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target_file
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so exclude it explicitly here.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    if not has_file_handler:
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    if not has_console_handler:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    return logger

def train(model, train_loader, criterion, epoch, writer, logger):
    """Run one training epoch and report the mean batch loss.

    Args:
        model: Engine-like object exposing ``train()``, ``local_rank``,
            ``backward(loss)`` and ``step()``, and callable on a batch.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> loss``.
        epoch: Zero-based epoch index; shown as ``epoch + 1`` in log lines
            and used as the step for the TensorBoard scalar.
        writer: Object with ``add_scalar(tag, value, step)``.
        logger: Logger used for progress messages.

    Only the process whose ``local_rank`` is 0 logs and writes scalars.
    The reported average covers the batches seen by this process only.
    """
    model.train()
    num_batches = len(train_loader)
    is_main_process = model.local_rank == 0
    total_loss = 0.0

    for i, (images, labels) in enumerate(train_loader):
        # Move the batch to this process's device.
        # NOTE(review): inputs are cast to fp16 unconditionally; this
        # presumably matches an fp16-enabled engine config - confirm.
        images = images.to(model.local_rank).half()
        labels = labels.to(model.local_rank)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimizer step, both delegated to the engine.
        model.backward(loss)
        model.step()

        # Read the scalar once per step and reuse it for both the running
        # total and the log line.
        loss_value = loss.item()
        total_loss += loss_value
        if i % 10 == 0 and is_main_process:
            logger.info(f"Epoch [{epoch+1}], Step [{i+1}/{num_batches}], Loss: {loss_value:.4f}")

    if num_batches == 0:
        # Nothing was trained; avoid dividing by zero below.
        if is_main_process:
            logger.warning(f"Epoch [{epoch+1}] training loader is empty; no loss recorded.")
        return

    avg_loss = total_loss / num_batches
    if is_main_process:
        logger.info(f"Epoch [{epoch+1}] Training Loss: {avg_loss:.4f}")
        writer.add_scalar('Train/Loss', avg_loss, epoch)

def val(model, val_loader, criterion, epoch, writer, logger):
    """Evaluate the model on the validation loader.

    Args:
        model: Engine-like object exposing ``eval()`` and ``local_rank``,
            and callable on a batch.
        val_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> loss``.
        epoch: Step value used for the TensorBoard scalars.
        writer: Object with ``add_scalar(tag, value, step)``.
        logger: Logger used for the summary message.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly classified samples (0..1). If the loader
        yields no samples both values are NaN, so comparisons such as
        ``loss < best`` and ``acc > best`` are False for that epoch.

    Only the process whose ``local_rank`` is 0 logs and writes scalars.
    """
    model.eval()
    num_batches = len(val_loader)
    total_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            # Move the batch to this process's device.
            # NOTE(review): inputs are cast to fp16 unconditionally; this
            # presumably matches an fp16-enabled engine config - confirm.
            images = images.to(model.local_rank).half()
            labels = labels.to(model.local_rank)

            # Forward pass.
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Accumulate loss and the number of correct predictions; the
            # predicted class is the index of the largest value along dim 1.
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if num_batches == 0 or total == 0:
        # No data was evaluated; report undefined metrics instead of
        # failing with a division by zero.
        if model.local_rank == 0:
            logger.warning("Validation loader is empty; skipping validation metrics.")
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = correct / total

    if model.local_rank == 0:
        logger.info(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")
        writer.add_scalar('Val/Loss', avg_loss, epoch)
        writer.add_scalar('Val/Accuracy', accuracy, epoch)

    return avg_loss, accuracy

def main(data_path: str, batch_size: int, model_name: str, stage: int) -> None:
    """Train and validate one model under the given ZeRO stage.

    Loads ``./config/config.json``, overrides the ZeRO stage in its
    DeepSpeed section, trains for ``train_config.epochs`` epochs, and after
    every epoch the local-rank-0 process saves the weights whenever the
    validation loss or accuracy improves and records CUDA memory usage.

    Args:
        data_path: Passed through to ``get_data``.
        batch_size: Passed to ``get_data`` and used as the validation
            loader's batch size.
        model_name: Passed through to ``get_model``.
        stage: Value written to ``zero_optimization.stage`` and used in
            the names of the TensorBoard directory, log file and
            checkpoint files.
    """
    config = load_config('./config/config.json')

    # Override the ZeRO stage in the DeepSpeed section of the config.
    config['deepspeed_config']['zero_optimization']['stage'] = stage

    # TensorBoard events and checkpoints share one per-stage directory.
    output_dir = f'./scalar/DeepSpeed-stage{stage}'
    writer = SummaryWriter(output_dir)

    # The writer is created on every process, so it is closed on every
    # process too, including when training raises.
    try:
        logger = setup_logger('./logs', stage)
        logger.info(f"Starting training with ZeRO stage {stage}...")

        train_dataset, val_dataset = get_data(data_path=data_path, batch_size=batch_size)
        model = get_model(model_name, classes=50)

        criterion = nn.CrossEntropyLoss()

        # Only the engine and the training loader are used from the
        # values returned by deepspeed.initialize.
        model_engine, _, train_loader, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            training_data=train_dataset,
            config=config['deepspeed_config']
        )

        # Validation uses a plain, unshuffled DataLoader.
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False
        )

        best_val_loss = float('inf')
        best_val_acc = 0.0

        for epoch in range(config['train_config']['epochs']):
            train(model_engine, train_loader, criterion, epoch, writer, logger)
            val_loss, val_acc = val(model_engine, val_loader, criterion, epoch, writer, logger)

            # Checkpointing and memory logging happen on local rank 0 only.
            if model_engine.local_rank == 0:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model_engine.state_dict(), f'{output_dir}/best_model_stage_{stage}_loss.pth')
                    logger.info(f"Saved best model with validation loss: {best_val_loss:.4f}")

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save(model_engine.state_dict(), f'{output_dir}/best_model_stage_{stage}_acc.pth')
                    logger.info(f"Saved best model with validation accuracy: {best_val_acc * 100:.2f}%")

                # Record CUDA memory for this device, converted from
                # bytes to MiB.
                memory_allocated = torch.cuda.memory_allocated(model_engine.local_rank) / 1024 ** 2
                memory_reserved = torch.cuda.memory_reserved(model_engine.local_rank) / 1024 ** 2
                writer.add_scalar('Memory/Allocated_MB', memory_allocated, epoch)
                writer.add_scalar('Memory/Reserved_MB', memory_reserved, epoch)

        if model_engine.local_rank == 0:
            logger.info(f"Training with ZeRO stage {stage} completed.")
    finally:
        writer.close()

if __name__ == '__main__':
    # Read the configuration once; the run parameters below come from it.
    config = load_config('./config/config.json')
    train_settings = config['train_config']
    configured_batch_size = config['deepspeed_config']['train_batch_size']

    # Run the whole pipeline once per ZeRO stage: 0, 1 and 2.
    for stage in (0, 1, 2):
        main(
            data_path=train_settings['data_path'],
            batch_size=configured_batch_size,
            model_name=train_settings['model_name'],
            stage=stage,
        )